What do the pdf updates look like?¶
Exploring how the input distribution changes under Fisher information updates.
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize
import seaborn as sns
import plotly.express as px
import numpy as np
import pandas as pd
import polars as pl
import statsmodels.formula.api as smf
import statsmodels.api as sm
import matplotlib
import torch
from discriminationAnalysis import Fisher_smooth_fits
from basicModel import EstimateAngle
from adaptableModel import AdaptableEstimator, AngleDistribution
from adapt_fit_loop import moving_average
pdf update for the concentrated case¶
import glob
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.title('Fisher Information: trained networks')
ex4_conc_dir = 'trainedParameters/Exp4_conc/'
FIcurves_conc = []
for rep in range(6):
trained_ckpt = glob.glob(ex4_conc_dir + f'rep{rep}/epoch*')[0]
model = EstimateAngle.load_from_checkpoint(trained_ckpt)
fi = Fisher_smooth_fits(model, 0., np.pi, N_cov=500, Samp_cov=500)
FIcurves_conc.append(fi)
plt.plot(np.linspace(0, 2*np.pi, 500), fi)
np.array(FIcurves_conc).min(1)
array([14.69482853, 9.7832375 , 0.17993534, 17.95820599, 25.27543399,
18.58125288])
fisher_curves = np.array(FIcurves_conc)
smoothed_mean_fisher = moving_average(np.mean(fisher_curves, axis=0))
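`moving_average` comes from `adapt_fit_loop` and its implementation isn't shown here; as a rough mental model (an assumption about its behavior, not the actual code), a length-preserving boxcar smooth would look like:

```python
import numpy as np

def moving_average_sketch(x, window=25):
    """Hypothetical stand-in for adapt_fit_loop.moving_average:
    a centered boxcar smooth that preserves the input length."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode='same')
```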
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.plot(np.linspace(0, 2*np.pi, 500), smoothed_mean_fisher)
unif = np.ones(500)
p1 = unif / smoothed_mean_fisher**0.5
p1 = p1 / p1.sum()
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.plot(np.linspace(0, 2*np.pi, 500), p1, 'k')
Experiment 4 (concentrated) was trained with the data concentrated around pi/2.
Note that for these plots I am doubling the angular scale, for ease of visualizing the periodic signal. Thus, angles plotted close to zero are, in fact, orthogonal to the pi/2 angles.
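To make the doubled scale concrete (my own illustration, not the notebook's plotting code): orientations live on [0, pi), and plotting 2*theta spreads them over the full polar circle, which places theta = 0 and theta = pi/2 diametrically opposite, i.e. orthogonal orientations land on opposite sides of the plot.

```python
import numpy as np

# Orientations are defined on [0, pi); doubling the angle maps that
# half-circle onto the full [0, 2*pi) polar axis so the period closes.
thetas = np.linspace(0, np.pi, 500, endpoint=False)
plot_angles = 2 * thetas

# theta = pi/2 is plotted at angle pi, diametrically opposite theta = 0:
assert np.isclose(2 * (np.pi / 2), np.pi)
```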
What is the idea of the iteration?
- this network has its Fisher information concentrated at pi/2.
- fine-tuning on the reverse, concentrated around 0, should, thus, remove this sensitivity bias.
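The bullets above amount to a fixed-point iteration on the stimulus density; here is a schematic of one step (my own sketch, not the training loop itself):

```python
import numpy as np

def pdf_update(p, fisher, eps=1e-12):
    """One iteration step: put less probability mass where the network
    is already sensitive (high Fisher information), then renormalize."""
    p_new = p / np.sqrt(fisher + eps)
    return p_new / p_new.sum()

# Example: a Fisher curve peaked at pi/2 pushes mass away from pi/2.
x = np.linspace(0, np.pi, 500)
fisher = 1.0 + 5.0 * np.exp(-(x - np.pi / 2) ** 2 / 0.1)
p = pdf_update(np.ones(500) / 500, fisher)
```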
Fine-tune what? Should we fine-tune this network (the already-trained one), or the previous network that we trained, in order to get this response to the current stimulus distribution?
Note for the record:¶
These are separately trained versions of the network that we are averaging over to determine the Fisher information.
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
for fi in FIcurves_conc:
c = unif / fi**0.5
plt.plot(np.linspace(0, 2*np.pi, 500), c / c.sum())
The individual runs can produce very noisy results.
Ideas as to why the convergence fails:¶
- There are modes that are amplified, rather than damped in the fitting process.
- essentially this is the case if the distribution above is more extreme than the distribution that the network was trained on.
from scipy.stats import vonmises
from scipy.integrate import trapezoid
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
x = np.linspace(0, 2*np.pi, 500)
p1 = unif / smoothed_mean_fisher**0.5
p1 = p1 / trapezoid(p1, x)
plt.plot(x, vonmises(8., 0).pdf( x))
plt.plot(x, p1, 'k')
I have to be careful here about how these are normalized, and to what x domain. However, the fact that the values near zero are 'rounded out' is hopeful.
from adaptableModel import AngleDistribution
# this is the former default
a = AngleDistribution(p1, [-np.pi, np.pi])
update_samples = a.sample(10000)
prior_samples = vonmises(8., 0).rvs( 10000)
n_update, b_update = np.histogram(update_samples, bins=50, density=True)
plt.plot(b_update[1:], n_update)
n_prior, b_prior = np.histogram(prior_samples, bins=50, density=True)
plt.plot(b_prior[1:], n_prior, 'k')
Ok, this is reasonably good: it shows that the distribution is, in fact, getting more uniform, at least for this one step.
test = AngleDistribution( vonmises(8., 0).pdf( x), [-np.pi, np.pi])
test_num, test_bins = np.histogram(test.sample(10000), bins=50, density=True)
plt.plot(test_bins[1:], test_num)
plt.plot(b_prior[1:], n_prior, 'k')
These are possible issues:¶
- AngleDistribution appears to rotate the angles by 180 degrees.
- AngleDistribution also uses angles between -pi and pi, rather than 0 to pi (or some other half-length parameterization), which is inconsistent with the generation.
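The second issue is easy to probe directly; a minimal wrap helper (my own, not part of `adaptableModel`) shows how a 180-degree offset interacts with the [-pi, pi) convention:

```python
import numpy as np

def wrap_to_pm_pi(theta):
    """Wrap arbitrary angles into [-pi, pi)."""
    return (np.asarray(theta) + np.pi) % (2 * np.pi) - np.pi

# [0, pi) already sits inside [-pi, pi), so generated angles pass through:
assert np.allclose(wrap_to_pm_pi(np.pi / 2), np.pi / 2)
# But a rotation by pi (the suspected 180-degree offset) maps 0 -> -pi:
assert np.isclose(wrap_to_pm_pi(0 + np.pi), -np.pi)
```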
test_num, test_bins = np.histogram(test.sample(10000), bins=50, density=True)
rotated_num, rotated_bins = np.histogram( vonmises(8.,np.pi).rvs(10000), bins=50, density=True)
plt.plot(test_bins[1:], test_num)
plt.plot(rotated_bins[1:], rotated_num, '--k')
The density itself seems reasonably unchanged.
What do the iterates look like?¶
from adaptableModel import AngleDistribution
data = pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')
colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
itr = row['iteration']
if itr > 0:
itr = itr-1
dist = AngleDistribution(row['data'], [0, np.pi])
plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])
plt.title('Untrained Network Trajectory')
data = pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')
colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
itr = row['iteration']
if itr > 0:
itr = itr-1
dist = AngleDistribution(row['data'], [0, np.pi])
plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])
plt.title('Uniform Network Trajectory')
data = pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
colors = plt.cm.viridis(np.linspace(0,1,8))
for row in data[ data.measurement=='probability'].to_dict(orient="records"):
itr = row['iteration']
if itr > 0:
itr = itr-1
dist = AngleDistribution(row['data'], [0, np.pi])
plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])
plt.title('Concentrated Network Trajectory')
xs = np.linspace(-1, 1, 201)
for i in range(8):
plt.plot(xs, i* xs, c=colors[i])
The trend here is clear.¶
It seems that the small deviations in the probability distributions are being successively amplified through the process of finetuning and fitting the Fisher information.
This is a little surprising:
The concentrated network appears to have its biases pretty well removed by the initial fine-tuning, but the newly introduced small biases aren't trained away? How does that make sense?
concentrated_data =pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
raw = concentrated_data[(concentrated_data.measurement == 'FI') &
(concentrated_data.iteration == 0)]['data']
xs = np.linspace(0, np.pi,500)
for row in raw:
plt.plot(xs, row)
smoothed_mean_fisher = moving_average(np.mean(raw, axis=0))
plt.plot(xs, smoothed_mean_fisher, 'k')
recorded = concentrated_data[(concentrated_data.iteration == 2) &
(concentrated_data.measurement == 'probability')]['data'].iloc[0]
plt.plot(xs, 1./smoothed_mean_fisher**0.5)
plt.plot(xs, recorded, '--k')
Ok, good double check: the probability distribution is calculated as expected.
interim summary:¶
This is a seemingly paradoxical result:
- In this case, $p(s_1) - p(s_2) > 0$ means that $p'(s_1) - p'(s_2) > p(s_1) - p(s_2)$. Thus, $p(s_1) / \sqrt{I(s_1)} - p(s_2) / \sqrt{I(s_2)} > p(s_1) - p(s_2)$...
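Spelling out the bullet above (my restatement, assuming the update rule $p'(s) \propto p(s) / \sqrt{I(s)}$):

$$\frac{p'(s_1)}{p'(s_2)} = \frac{p(s_1)}{p(s_2)} \sqrt{\frac{I(s_2)}{I(s_1)}}$$

So whenever the higher-probability point $s_1$ also has the lower Fisher information, this ratio strictly increases, and the deviation from uniform grows with every iteration.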
Is there actually a negative correlation between the probability density that the networks were trained on and the Fisher information in the resulting networks?¶
What if we plot the pdf that we fine-tuned on vs the Fisher information measurements?
concentrated_data = pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
uniform_data = pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')
untrained_data = pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')
# fix the indexing issues in the experiment script
inds = concentrated_data[concentrated_data['measurement'] == 'probability'].index
concentrated_data.loc[inds, 'iteration'] = range(8)
uniform_data.loc[inds, 'iteration'] = range(8)
untrained_data.loc[inds, 'iteration'] = range(8)
pt = concentrated_data.pivot_table(index='iteration', columns='measurement',
values='data', aggfunc='mean')
for row in pt.to_dict(orient="records")[::-1]:
dist = AngleDistribution(row['probability'], [0, np.pi])
plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
pt = uniform_data.pivot_table(index='iteration', columns='measurement',
values='data', aggfunc='mean')
for row in pt.to_dict(orient="records")[::-1]:
dist = AngleDistribution(row['probability'], [0, np.pi])
plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
pt = untrained_data.pivot_table(index='iteration', columns='measurement',
values='data', aggfunc='mean')
for row in pt.to_dict(orient="records")[::-1]:
dist = AngleDistribution(row['probability'], [0, np.pi])
plt.plot(dist.bin_probs, row['FI'][1:], '.')
plt.xlabel('probability')
plt.ylabel('Fisher Info')
Note that I'm plotting these in reverse order (so grey is the first iteration, then pink, etc).
As the iterations continue (and the probabilities diffuse outward), we see the gradual emergence of an inverse correlation between probability in the training set and the Fisher information of the learned mapping.
This is the exact opposite of the case previously!
Did the previous case actually show what I thought?¶
import glob
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
plt.title('Fisher Information: trained networks')
ex4_conc_dir = 'trainedParameters/Exp4_conc/'
FIcurves_conc = []
for rep in range(6):
trained_ckpt = glob.glob(ex4_conc_dir + f'rep{rep}/epoch*')[0]
model = EstimateAngle.load_from_checkpoint(trained_ckpt)
fi = Fisher_smooth_fits(model, 0., np.pi, N_cov=500, Samp_cov=500)
FIcurves_conc.append(fi)
print(rep)
plt.plot(np.linspace(0, 2*np.pi, 500), fi)
0 1
KeyboardInterrupt: the rerun was interrupted inside Fisher_smooth_fits (during noisy sample generation in generateGrating); only reps 0 and 1 completed.
from scipy.stats import vonmises
dist_exp4 = vonmises(8., np.pi/2)
samples = dist_exp4.rvs(10000)
plt.hist(samples, bins=20)
plt.hist(samples %np.pi, bins=20)
plt.plot(samples, samples % np.pi, '.')
Ok, this is nice: our distributions are very concentrated, so the way that we projected the data down to (0, pi) doesn't really matter.
samples = dist_exp4.rvs(50000) % np.pi
count, bins = np.histogram(samples, bins=np.linspace(0, np.pi, 501), density=True)
for y in FIcurves_conc:
plt.plot(count, y, '.')
plt.xlabel('Probability density')
plt.ylabel('Fisher Information');
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
for y in FIcurves_conc:
plt.plot(np.linspace(0, 2*np.pi, 500), y)
plt.plot(np.linspace(0, 2*np.pi, 500), 10000*count,'k')
Yes, indeed. Here we see a very different dependence between the probability density of the training distribution and the Fisher information in the learned network.
High density -> high Fisher information, and vice versa.
It's good to have triple checked this.
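The "high density -> high Fisher information" claim can be made quantitative with a Pearson correlation; a sketch with synthetic stand-ins for `count` and one of the `FIcurves_conc` entries (the shapes, not the notebook's actual values):

```python
import numpy as np

x = np.linspace(0, np.pi, 500)
# Stand-ins: training density and Fisher curve both peaked at pi/2.
density_proxy = np.exp(-(x - np.pi / 2) ** 2 / 0.2)
fi_proxy = 5000 + 4000 * np.exp(-(x - np.pi / 2) ** 2 / 0.3)

# Pearson correlation between density and Fisher information:
r = np.corrcoef(density_proxy, fi_proxy)[0, 1]
# r is strongly positive when both curves peak at the same location.
```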
Notes:¶
Thinking more about it, it is very strange to get such clean behavior. This is an average of multiple retrained samples, but it looks like there is basically no noise?
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})
for ind in range(6):
for fi in concentrated_data[(concentrated_data.measurement == 'FI')
& (concentrated_data.iteration == ind)]['data']:
plt.subplot(2,3,ind+1)
plt.plot(np.linspace(0, 2*np.pi, 500), fi)
No, but seriously, how are these so damn similar???
ckpts = glob.glob('trainedParameters/Exp6/concentrated/iter0/*')
from adaptableModel import AdaptableEstimator, AngleDistribution
U = AngleDistribution(np.ones(500), [0., np.pi])
for ckpt in ckpts:
model = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=U)
print( model.hparams.seed )
967369843898950914 15346541979810798859 12711906263299886879 14665669494729647154 16652516071571257433
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})
for ind in range(6):
for fi in uniform_data[(uniform_data.measurement == 'FI')
& (uniform_data.iteration == ind)]['data']:
plt.subplot(2,3,ind+1)
plt.plot(np.linspace(0, 2*np.pi, 500), fi)
fig, ax = plt.subplots(2, 3, subplot_kw={'projection': 'polar'})
for ind in range(6):
for fi in untrained_data[(untrained_data.measurement == 'FI')
& (untrained_data.iteration == ind)]['data']:
plt.subplot(2,3,ind+1)
plt.plot(np.linspace(0, 2*np.pi, 500), fi)
Is it possibly overfitting?¶
The models that I'm using to assess the Fisher information aren't necessarily the best checkpoints, since I don't reload these from the file...
Indeed, running a quick check shows that final model parameters are not the optimal ones.
I can test this by comparing the recorded Fisher information to the fits to Fisher information.
concentrated_rerun = []
flat_dist = AngleDistribution(np.ones(500), [0., np.pi])
for iter in range(8):
files = glob.glob(f'trainedParameters/Exp6/concentrated/iter{iter}/*')
for checkpoint in files:
model = AdaptableEstimator.load_from_checkpoint(checkpoint,
angle_dist=flat_dist,
max_epochs=0
)
fi = Fisher_smooth_fits(model, 0., np.pi, N_mean=10000, N_cov=500, Samp_cov=500)
row = {'iteration': iter, 'data': fi}
concentrated_rerun.append(row)
print(iter)
0 1 2 3 4 5 6 7
concentrated_rerun = pd.DataFrame(concentrated_rerun)
concentrated_rerun.groupby('iteration').mean()
| iteration | data |
|---|---|
| 0 | [7621.458407643084, 7004.0889070300955, 6895.1... |
| 1 | [8126.678988741916, 8286.100369626003, 7994.64... |
| 2 | [7979.954776736768, 7580.665071036531, 7865.98... |
| 3 | [8077.784758112767, 7796.0586707988205, 8172.8... |
| 4 | [8362.651840080049, 8816.162014109344, 8063.81... |
| 5 | [8089.661711997667, 7962.0554153758085, 7504.7... |
| 6 | [8612.287292079103, 8410.9201702268, 8117.0090... |
| 7 | [8226.916994561674, 8726.015627317034, 7589.94... |
colors = plt.cm.viridis(np.linspace(0,1,8))
i =0
for row in concentrated_data[ concentrated_data.measurement == 'FI'
].groupby('iteration').agg({'data': 'mean'}).sort_values('iteration')['data']:
plt.plot(moving_average(row), c=colors[i])
i += 1
plt.title('Fisher information - online record')
colors = plt.cm.viridis(np.linspace(0,1,8))
i =0
for row in concentrated_rerun.groupby('iteration').mean().sort_values('iteration')['data']:
plt.plot(moving_average(row), c=colors[i])
i += 1
plt.title('Fisher Information - posthoc evaluation')
There is some difference between the two sets of curves.
The differences across iterations don't look strong enough to produce the trend that we've observed in the previous plots.
colors = plt.cm.viridis(np.linspace(0,1,9))
dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])
i = 1
for row in concentrated_data[ concentrated_data.measurement == 'FI'
].groupby('iteration').agg({'data': 'mean'}).sort_values('iteration')['data']:
new_values = dist.values / moving_average(row)**0.5
dist = AngleDistribution(new_values, [0., np.pi])
plt.plot(dist.bin_probs, c=colors[i])
i += 1
plt.title('Probability distribution')
colors = plt.cm.viridis(np.linspace(0,1,9))
dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])
i = 1
for row in concentrated_rerun.groupby('iteration').mean().sort_values('iteration')['data']:
new_values = dist.values / moving_average(row)**0.5
dist = AngleDistribution(new_values, [0., np.pi])
plt.plot(dist.bin_probs, c=colors[i])
i += 1
Interesting result¶
Ok, so it looks like the thing that causes this divergence is actually the fact that the Fisher information deviates from uniform in the same characteristic way every time we refit the neural networks.
That is to say, it's because the Fisher information curves all have the same large-scale shape.
This means that we divide by the same thing every time, which in turn causes the amplification of deviations.
That also explains the very deterministic seeming nature of the deviations: it really is repeated division by the same curve.
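The mechanism described above is easy to reproduce in isolation: dividing repeatedly by the same (even slightly) non-uniform curve amplifies its imprint geometrically. A toy demonstration:

```python
import numpy as np

x = np.linspace(0, np.pi, 500)
# A Fisher-like curve with a small, fixed 5% ripple.
fisher = 1.0 + 0.05 * np.cos(4 * x)

p = np.ones_like(x) / x.size
ratios = []
for _ in range(8):
    p = p / np.sqrt(fisher)   # same curve every iteration
    p = p / p.sum()
    ratios.append(p.max() / p.min())

# The max/min ratio of p grows geometrically across iterations,
# even though the per-step ripple is only 5%.
```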
The question is: why is it the same?¶
Why do the networks all share this structure??
- Is it the shared initialization?
- Is it a randomization failure?
- Is it an architecture failure?
- Is it a Fisher information fitting failure?
Seed sharing¶
for ckpt in glob.glob('trainedParameters/Exp6/concentrated/iter0/*'):
dist = AngleDistribution(np.ones(500), [0, np.pi])
model = AdaptableEstimator.load_from_checkpoint( ckpt, angle_dist=dist)
print(model.hparams.seed)
967369843898950914 15346541979810798859 12711906263299886879 14665669494729647154 16652516071571257433
for ckpt in glob.glob('trainedParameters/Exp6/concentrated/iter1/*'):
dist = AngleDistribution(np.ones(500), [0, np.pi])
model = AdaptableEstimator.load_from_checkpoint( ckpt, angle_dist=dist)
print(model.hparams.seed)
10666304192740685704 12834590379526135592 10890395108299221895 11197852224913400512 13072276661895679149
The seeds that we recorded are not the same. Is it possible that the seed is not being set?
import torch
model1 = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=dist, seed=torch.random.seed())
model2 = AdaptableEstimator.load_from_checkpoint(ckpt, angle_dist=dist, seed=torch.random.seed())
print(model1.hparams.seed, model2.hparams.seed)
13813948499615692093 7102135249791225486
model1.setup()
model2.setup()
model1.trainingData.angles
tensor([1.3316, 0.9727, 2.6663, ..., 0.2046, 1.8856, 2.8803])
model2.trainingData.angles
tensor([1.5960, 1.6338, 2.3326, ..., 0.0220, 2.0367, 2.5341])
plt.imshow(model1.trainingData.images[0] - model2.trainingData.images[0])
Certainly the data generated within the two models looks different.
Also, the initialization code seems to run upon loading the models.
I'm pretty confident that the seeds are different between the models.¶
Single iterates¶
for fi in concentrated_data[(concentrated_data['measurement'] == 'FI')]['data']:
plt.plot(fi)
It definitely looks more correlated than I would expect.
colors = plt.cm.viridis(np.linspace(0,1,6))
dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])
i = 1
for row in concentrated_data[(concentrated_data.measurement == 'FI') &
(concentrated_data.iteration == 0)
]['data']:
new_values = dist.values / moving_average(row)**0.5
dist = AngleDistribution(new_values, [0., np.pi])
plt.plot(dist.bin_probs, c=colors[i])
i += 1
plt.title('Do all replicates in each iteration have the same structure?')
Yep, this is the same behavior as across iterates: there is a surprising amount of shared structure between the replicates.
colors = plt.cm.viridis(np.linspace(0,1,6))
dist = AngleDistribution( np.ones(500), [0., np.pi])
plt.plot(dist.bin_probs, c=colors[0])
i = 1
for row in concentrated_data[(concentrated_data.measurement == 'FI') &
(concentrated_data.iteration == 0)
]['data']:
new_values = dist.values / row**0.5
dist = AngleDistribution(new_values, [0., np.pi])
plt.plot(dist.bin_probs, c=colors[i])
i += 1
plt.title('What about the un-smoothed versions?')
I mean, yes. This is essentially the same as the behavior that we saw above.
So the smoothing is not the cause of the similarity.
Is it a result of the shared initialization?¶
Seen this way, the initialization looks like a sort of constraint on the network that gets pulled out of the noise?
Experiment: run different initializations, and see if the Fisher information of the learned networks is dependent on the initialization.
data= pd.read_pickle('experiment_result/ex6_initialization.pickle')
data
| | rep | method | Fisher |
|---|---|---|---|
| 0 | 0 | loaded | [7482.413039576917, 6917.13476382538, 6849.097... |
| 1 | 0 | loaded | [10166.779750083673, 8939.67076884833, 9921.07... |
| 2 | 0 | loaded | [8342.131900495337, 7754.1581290318945, 8720.9... |
| 3 | 0 | set state | [10246.708972247876, 10534.179180788991, 10474... |
| 4 | 0 | set state | [8194.859422943355, 9120.434044892028, 10009.6... |
| 5 | 0 | set state | [9891.966864509444, 10917.245571615525, 9565.3... |
| 6 | 1 | loaded | [9103.473100671537, 8093.2921956849395, 7738.5... |
| 7 | 1 | loaded | [8049.016757678765, 9203.03980910144, 10708.35... |
| 8 | 1 | loaded | [8168.780814261347, 7271.36786233274, 7416.672... |
| 9 | 1 | set state | [7484.556178987075, 7061.951097034991, 7524.17... |
| 10 | 1 | set state | [8361.456949634292, 9376.600646718349, 8269.81... |
| 11 | 1 | set state | [7888.846354406538, 8725.683978198615, 8975.97... |
| 12 | 2 | loaded | [7597.5959725953635, 8650.741227010567, 7785.3... |
| 13 | 2 | loaded | [8212.265420675509, 8202.753571510018, 9572.20... |
| 14 | 2 | loaded | [8401.989415222208, 8165.419969414646, 9289.19... |
| 15 | 2 | set state | [7875.954741235541, 8403.249508216057, 7851.75... |
| 16 | 2 | set state | [8102.391485970146, 8428.492809881152, 7607.05... |
| 17 | 2 | set state | [7750.2276935275795, 8183.110812613867, 9789.5... |
| 18 | 3 | loaded | [7888.236253710033, 7438.313531107031, 7802.09... |
| 19 | 3 | loaded | [9388.040978595753, 11182.55700286475, 10760.2... |
| 20 | 3 | loaded | [9358.530614150337, 8728.840107521382, 8636.65... |
| 21 | 3 | set state | [7974.953282728307, 8173.304388463969, 8817.06... |
| 22 | 3 | set state | [7547.519832147966, 7306.104821011238, 5886.70... |
| 23 | 3 | set state | [8097.5375863064755, 7504.368573390664, 7586.6... |
| 24 | 4 | loaded | [9767.59402579983, 8139.895565234726, 9224.221... |
| 25 | 4 | loaded | [8548.035607218688, 7886.105720024892, 7104.46... |
| 26 | 4 | loaded | [8347.773943916216, 9773.953499879903, 8465.08... |
| 27 | 4 | set state | [7509.38839085409, 7101.0802125403625, 8569.01... |
| 28 | 4 | set state | [9480.810211138492, 9434.11025135334, 11616.58... |
| 29 | 4 | set state | [9012.03739098086, 8643.40802404131, 8002.1764... |
| 30 | 5 | loaded | [8199.180526830847, 7921.528926425352, 7307.44... |
| 31 | 5 | loaded | [7655.612771660323, 7651.739397811222, 8399.38... |
| 32 | 5 | loaded | [7890.474416216101, 7784.39781415223, 7951.540... |
| 33 | 5 | set state | [9467.164749292322, 8209.064123106846, 9037.32... |
| 34 | 5 | set state | [8637.367718021858, 8000.276341570653, 8157.63... |
| 35 | 5 | set state | [7879.291543139089, 6833.589572237475, 8048.05... |
plt.subplots(3,2)
for row in data[data.method == 'loaded'].to_dict(orient="records"):
plt.subplot(3,2,row['rep']+1)
plt.plot(moving_average(row['Fisher']))
plt.subplots(3,2)
for row in data[data.method == 'set state'].to_dict(orient="records"):
plt.subplot(3,2,row['rep']+1)
plt.plot(moving_average(row['Fisher']))
Ok, it's hard to tell by eye whether the series are more similar within bins than between bins.
colors = plt.cm.viridis(np.linspace(0,1,7))
def group_std(x): return x.to_numpy().std()
xs = np.linspace(0, np.pi, 500)
i=1
for row in data[data.method == 'loaded'].groupby('rep'
).agg({'Fisher':['mean', group_std]}
).to_dict(orient="records"):
mean = row[('Fisher', 'mean')]
err = row[('Fisher','group_std')]
plt.plot(xs, mean, c=colors[i])
plt.plot(xs, mean-err, '--', c=colors[i])
plt.plot(xs, mean+err, '--', c=colors[i])
i+=1
plt.title('Replicate means and variances')
That visualization is worthless
sns.heatmap(np.corrcoef( np.array(data['Fisher'].to_list()) ))
plt.title('Pearson correlation heat map');
sns.heatmap(np.corrcoef( np.array(data['Fisher'].to_list()) ), vmin=-0.3, vmax=0.5)
plt.title('Pearson correlation heat map - zoom');
There doesn't really seem to be much structure in this measurement. We are looking for structure in the form of squares of side 3 or 6 along the diagonal.
Maybe I can convince myself that there is some such structure (e.g. 12-26, 7-10, 24-26), but it doesn't align with the changes in initialization.
I conclude that either this is a poor measure, or there is no such initialization structure.
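The block structure being looked for can also be tested numerically rather than by eye: compare mean within-group correlation to mean between-group correlation. A sketch on synthetic curves (my own helper; `data` isn't reused here so the example stands alone):

```python
import numpy as np

def within_between(corr, group_size):
    """Mean within-group vs between-group correlation, for groups of
    consecutive rows along the diagonal of a correlation matrix."""
    n = corr.shape[0]
    within, between = [], []
    for i in range(n):
        for j in range(i + 1, n):
            same = (i // group_size) == (j // group_size)
            (within if same else between).append(corr[i, j])
    return np.mean(within), np.mean(between)

# Synthetic check: 36 curves in 6 groups sharing a group-specific shape.
rng = np.random.default_rng(0)
x = np.linspace(0, np.pi, 500)
curves = np.vstack([np.cos((g + 1) * x) + 0.5 * rng.standard_normal(500)
                    for g in range(6) for _ in range(6)])
w, b = within_between(np.corrcoef(curves), group_size=6)
# w clearly exceeds b when genuine group structure exists.
```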
Positive control¶
concentrated_data = pd.read_pickle('trainedParameters/Exp6/concentrated/iterate_data.pickle')
uniform_data = pd.read_pickle('trainedParameters/Exp6/uniform/iterate_data.pickle')
untrained_data = pd.read_pickle('trainedParameters/Exp6/untrained/iterate_data.pickle')
# fix the indexing issues in the experiment script
inds = concentrated_data[concentrated_data['measurement'] == 'probability'].index
concentrated_data.loc[inds, 'iteration'] = range(8)
uniform_data.loc[inds, 'iteration'] = range(8)
untrained_data.loc[inds, 'iteration'] = range(8)
plt.subplots(3,1)
for fi in concentrated_data[(concentrated_data['measurement'] == 'FI')]['data']:
plt.subplot(3,1,1)
plt.plot(fi)
plt.subplot(3,1,3)
plt.plot(fi)
for fi in uniform_data[(uniform_data['measurement'] == 'FI')]['data']:
plt.subplot(3,1,2)
plt.plot(fi)
plt.subplot(3,1,3)
plt.plot(fi)
By eye, the data seems to group into distinct curves. Is this captured by the mapping?
both = pd.concat([
concentrated_data[(concentrated_data['measurement'] == 'FI')],
uniform_data[(uniform_data['measurement'] == 'FI')]], axis=0)
sns.heatmap(np.corrcoef( np.array(both['data'].to_list()) ) )
Yeah, ok. This is pretty clear.
(note that these are the unsmoothed trajectories!)
all_traj = pd.concat([
concentrated_data[(concentrated_data['measurement'] == 'FI')],
uniform_data[(uniform_data['measurement'] == 'FI')],
untrained_data[(untrained_data['measurement'] == 'FI')]
], axis=0)
sns.heatmap(np.corrcoef( np.array(all_traj['data'].to_list()) ) )
The within-class correlation is greatly diminished when we look at the untrained initialization. In fact, it looks similar to the results from the second experiment.
for fi in untrained_data[(untrained_data['measurement'] == 'FI')]['data']:
plt.plot(fi)
Indeed, the untrained networks appear much more spread out than the pretrained ones!
However, this only shows up weakly in the Fisher information iteration:
for row in untrained_data[untrained_data.measurement == 'FI'].groupby('iteration'
).agg({'data': 'mean'}
)['data']:
plt.plot(row)
from adaptableModel import AngleDistribution
colors = plt.cm.viridis(np.linspace(0,1,8))
for row in untrained_data[ untrained_data.measurement=='probability'].to_dict(orient="records"):
itr = row['iteration']
dist = AngleDistribution(row['data'], [0, np.pi])
plt.plot(np.linspace(0, np.pi, dist.npoints-1), dist.bin_probs, c=colors[itr])
plt.title('Untrained Network Trajectory');
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
i=0
for row in untrained_data[untrained_data.measurement == 'FI'].groupby('iteration').agg({'data': 'mean'})['data']:
    smoothed_fi = moving_average(row)
    dist = AngleDistribution(dist.values / smoothed_fi, [0, np.pi])
    plt.plot(dist.bin_probs, c=colors[i])
    i += 1
plt.title('Untrained: recomputed bin probabilities');
What happens if I resample across iterates?
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
to_sample = untrained_data[untrained_data.measurement == 'FI']
for iter in range(8):
    inds = np.random.choice(range(40), 1)
    sample_mean = np.mean(to_sample.iloc[inds]['data'].to_list(), axis=0)
    smoothed_fi = moving_average(sample_mean)
    dist = AngleDistribution(dist.values / smoothed_fi, [0, np.pi])
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Untrained: random trajectory per iterate');
The same behavior emerges when we sample a random subset of iterations, and even when we simply use the first set of iterations: the probability distributions seem to diverge away from uniform.
dist.values[0:10]
array([5.24937531e-32, 5.05582106e-32, 4.89830818e-32, 4.70749811e-32,
4.60765458e-32, 4.56129761e-32, 4.54834814e-32, 4.52298889e-32,
4.59708996e-32, 4.60470690e-32])
Oh boy, is it round-off error??
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
to_sample = untrained_data[untrained_data.measurement == 'FI']
for iter in range(8):
    inds = np.random.choice(range(40), 1)
    sample_mean = np.mean(to_sample.iloc[inds]['data'].to_list(), axis=0)
    smoothed_fi = moving_average(sample_mean)
    new_probs = dist.values / smoothed_fi
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Renormalized to avoid round-off');
Nope, that's still not it.
I should fix this, though!
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
to_sample = untrained_data[untrained_data.measurement == 'FI']
for iter in range(8):
    sample_mean = 100*np.random.rand(500) + 100
    smoothed_fi = moving_average(sample_mean)
    new_probs = dist.values / smoothed_fi
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    plt.plot(dist.bin_probs, c=colors[iter])
plt.title('Purely Random Fisher information')
Text(0.5, 1.0, 'Purely Random Fisher information')
Hmm. Using random data, the iterates still appear to drift away from uniform, although not nearly as clearly.
This seems to be an instability in the iterations themselves.
dist = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
to_sample = untrained_data[untrained_data.measurement == 'FI']
plt.subplots(2,1)
vars = []
for iter in range(8):
    sample_mean = 100*np.random.rand(500) + 100
    new_probs = dist.values / sample_mean
    new_probs = new_probs / new_probs.sum()
    dist = AngleDistribution(new_probs, [0, np.pi])
    vars.append(dist.bin_probs.var())
    plt.subplot(2,1,1)
    plt.plot(dist.bin_probs, c=colors[iter])
    plt.title('Random - no smoothing')
plt.subplot(2,1,2)
plt.title('Variance')
plt.plot(vars, '.')
[<matplotlib.lines.Line2D at 0x170ab3fd0>]
dist1 = AngleDistribution(np.ones(500), [0, np.pi])
dist2 = AngleDistribution(np.ones(500), [0, np.pi])
colors = plt.cm.viridis(np.linspace(0,1,8))
to_sample = untrained_data[untrained_data.measurement == 'FI']
v1 = []
v2 = []
plt.subplots(3,1)
for iter in range(8):
    sample_mean = 100*np.random.rand(500) + 100
    new_probs1 = dist1.values / sample_mean
    new_probs1 = new_probs1 / new_probs1.sum()
    dist1 = AngleDistribution(new_probs1, [0, np.pi])
    v1.append(dist1.bin_probs.var())
    new_probs2 = dist2.values / moving_average(sample_mean)
    new_probs2 = new_probs2 / new_probs2.sum()
    dist2 = AngleDistribution(new_probs2, [0, np.pi])
    v2.append(dist2.bin_probs.var())
    plt.subplot(3,1,1)
    plt.title('No online smoothing - smoothed for plotting')
    plt.plot(moving_average(dist1.bin_probs), c=colors[iter])
    plt.subplot(3,1,2)
    plt.title('With online smoothing')
    plt.plot(dist2.bin_probs, c=colors[iter])
plt.subplot(3,1,1)
plt.suptitle('Head to head - smoothing or no');
plt.subplot(3,1,3)
plt.title('variance')
plt.plot(v1, '.', label='Not smoothed')
plt.plot(v2, '.', label='Smoothed')
plt.legend()
<matplotlib.legend.Legend at 0x171983f70>
Result¶
Ok, this is very illuminating: it looks like the divergence is caused by a combination of online smoothing and compounding noise.
Noise¶
This is caused by a lack of mean reversion: there is nothing pulling the distribution back toward uniform, so successive iterates simply diffuse farther away on average. I was hoping that the networks themselves would provide the mean reversion; absent that, the diffusion is inevitable. This shows up as the linearly increasing variance across iterates.
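A minimal simulation reproduces this diffusion (this is a sketch, not the experiment itself: a 500-bin distribution updated with iid noise standing in for the measured Fisher information). Because each update divides by an independent noise sample and renormalizes, log(p) performs a random walk across iterates, so its variance across bins grows roughly linearly.

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.ones(500) / 500                     # start from uniform
log_vars = []
for _ in range(8):
    fi = 100 * rng.random(500) + 100       # iid stand-in "Fisher information": no mean reversion
    p = p / fi                             # pdf update
    p = p / p.sum()                        # renormalize
    log_vars.append(np.log(p).var())       # spread of the log-probabilities across bins

# each step adds an independent increment to log(p), so the variance
# accumulates linearly: a random walk away from uniform
```

The normalization constant is shared across bins, so it cancels in the variance; only the accumulated noise increments remain.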
Smoothing¶
When we don't smooth, the iterates are independent noisy samples with linearly increasing variance. Smoothing introduces some serial dependence between the samples. By eye, this doesn't make the deviation worse; it does make the variance increase more slowly (not visible here, but it is still increasing).
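The variance-slowing effect of smoothing can be sketched with a simple stand-in for moving_average (the real one lives in adapt_fit_loop; the window size here is an assumption):

```python
import numpy as np

def moving_average_stub(x, w=25):
    """Hypothetical stand-in for adapt_fit_loop.moving_average; window size w is assumed."""
    return np.convolve(x, np.ones(w) / w, mode='same')

rng = np.random.default_rng(1)
raw = 100 * rng.random(500) + 100       # one noisy "Fisher information" sample
smooth = moving_average_stub(raw)

# averaging w neighbors cuts the per-step noise variance by roughly a factor of w,
# so the diffusion slows but does not stop
print(raw.var(), smooth.var())
```

The reduced per-step variance is bought at the cost of correlating neighboring bins, which is the serial dependence noted above.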
Important to dos:¶
- normalize the distribution values when they are updated to avoid round-off error. ✅
- investigate the uniform and concentrated networks: their trajectories are highly correlated.
- step 1: save weights only in the checkpoints ✅
- step 2: compare retraining with and without saving only the weights: do the results still hold? ✅
- Mean reversion scale of networks: what size of perturbations (width or height-wise) does the network revert?
Addressing the similarity of pre-trained networks¶
How did the hyperparameters change with iterations in the first experiment?¶
Maybe if the optimizer parameters shrink down to be really small, this would explain what we see.
model_data0 = torch.load('trainedParameters/Exp6/untrained/iter0/epoch=192-step=24704.ckpt',
map_location='cpu')
model_data1 = torch.load('trainedParameters/Exp6/untrained/iter1/epoch=149-step=19200.ckpt',
map_location='cpu')
model_data_init = torch.load('trainedParameters/Exp6/untrained.ckpt')
model_data1['optimizer_states'][0]
{'state': {0: {'step': tensor(19200.),
'exp_avg': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]),
'exp_avg_sq': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])},
1: {'step': tensor(19200.),
'exp_avg': tensor([ 3.9950e-08, 2.2974e-08, 3.7015e-08, 2.8690e-08, -1.5791e-08,
-2.2150e-08, 1.8757e-08, -1.8802e-08, -4.8500e-08, 9.0094e-08,
4.2652e-09, 4.8591e-09, -1.3988e-08, 5.3582e-08, 5.0246e-08,
1.8181e-08, -5.6266e-08, 1.1453e-07, 1.7005e-07, -8.2164e-09,
-1.8664e-07, -1.5436e-09, -1.1229e-07, -8.8392e-08, -9.8638e-08,
2.1685e-07, -1.0704e-08, 2.6447e-08, 1.4891e-08, -3.2266e-08,
-4.2134e-08, 1.2371e-08, 6.0001e-08, -5.4728e-08, -2.7055e-08,
-3.1886e-09, 4.7419e-08, -8.4212e-08, 7.0370e-08, 3.9694e-08,
-2.9081e-08, -1.9754e-08, 5.9768e-08, -4.7940e-08, -8.3292e-08,
3.1447e-08, 2.9186e-07, 3.7023e-09, -9.7024e-08, -5.2214e-08,
-3.2593e-08, 4.0332e-09, -2.7107e-08, -7.3101e-08, -1.3235e-08,
-1.7992e-07, 2.5443e-07, 1.5715e-08, -2.3786e-08, 5.1012e-08,
-1.5511e-07, -2.4289e-08, -4.6096e-08, 1.9318e-08, -2.1846e-09,
1.5979e-08, 4.0604e-08, 3.1175e-08, 6.5061e-08, -1.9150e-08,
-8.3222e-09, 2.1977e-08, -1.8897e-08, -1.5616e-08, -5.8125e-09,
-4.1069e-08, -6.1176e-08, -1.8599e-08, -9.8040e-08, 2.1829e-08,
9.3983e-09, -4.0516e-08, -4.6373e-08, -7.3704e-08, -1.8467e-08,
-3.2379e-08, -6.0409e-08, -1.0234e-07, -6.5500e-08, -2.0492e-07,
-6.5358e-08, -5.3337e-09, 2.9970e-08, 3.5137e-09, 3.9394e-08,
8.7156e-08, -9.1030e-09, -1.9250e-08, -4.0935e-08, -9.0614e-08,
6.6389e-08]),
'exp_avg_sq': tensor([2.9767e-13, 1.1291e-13, 1.5908e-13, 8.0122e-14, 2.2190e-14, 4.9398e-13,
6.5349e-14, 3.8010e-14, 1.6380e-13, 3.9109e-13, 6.3192e-14, 1.0245e-13,
2.2673e-13, 1.9131e-13, 2.0234e-13, 1.4576e-13, 2.5629e-13, 3.0221e-13,
4.2077e-13, 4.0549e-14, 5.4309e-13, 4.8692e-13, 1.6738e-13, 4.7163e-13,
1.3034e-13, 4.8022e-13, 1.3200e-14, 5.7385e-13, 2.4600e-13, 1.2623e-13,
1.9026e-13, 3.7299e-13, 2.8063e-13, 1.9123e-13, 2.8781e-13, 5.7911e-13,
2.9629e-13, 1.8802e-13, 2.7304e-13, 3.4340e-13, 1.7665e-13, 1.9906e-13,
2.8867e-13, 8.1971e-14, 7.1792e-13, 2.4038e-13, 3.5841e-13, 4.3775e-13,
3.7669e-13, 2.9406e-13, 1.7000e-13, 5.8291e-14, 1.1526e-13, 1.9875e-13,
3.2290e-13, 6.8843e-13, 1.0436e-12, 6.7710e-14, 2.2433e-13, 5.9616e-13,
4.3518e-13, 4.7085e-13, 5.1738e-13, 1.7276e-13, 1.6865e-13, 1.8499e-13,
2.2188e-13, 4.4883e-13, 2.5631e-13, 7.5779e-14, 1.7237e-13, 9.1031e-13,
4.4974e-13, 1.0054e-13, 5.7251e-13, 9.2990e-14, 1.2516e-12, 6.2778e-14,
3.5808e-13, 3.7556e-14, 5.4941e-13, 1.0891e-13, 6.5609e-14, 2.7047e-13,
2.1465e-13, 1.4819e-14, 1.0981e-13, 1.1930e-13, 5.1540e-13, 4.6762e-13,
7.3987e-14, 5.0239e-13, 7.7537e-14, 1.1692e-13, 1.1469e-13, 2.5231e-13,
2.2377e-14, 1.9400e-13, 7.7623e-14, 2.7656e-13, 4.8752e-13])},
2: {'step': tensor(19200.),
'exp_avg': tensor([[ 3.9751e-07, 3.5196e-07, 1.0579e-06, ..., -1.1782e-06,
9.9092e-07, 1.2076e-06],
[-4.0587e-09, -7.7670e-08, -7.0465e-09, ..., -4.8705e-07,
-3.4182e-06, -1.1919e-08],
[ 5.7123e-06, 4.2821e-06, 4.2875e-06, ..., -2.8615e-06,
8.7897e-09, 4.1755e-06],
...,
[-8.2580e-07, 1.5665e-06, -1.2411e-06, ..., -3.0074e-07,
-8.0529e-07, -2.2139e-07],
[-9.0013e-06, 1.1886e-06, -5.8878e-08, ..., -1.5740e-06,
-1.2207e-06, -8.4139e-06],
[ 1.9409e-08, -4.9159e-09, -4.8037e-08, ..., -2.8185e-07,
1.8700e-07, 2.9006e-08]]),
'exp_avg_sq': tensor([[2.6704e-12, 2.2422e-12, 8.4092e-12, ..., 1.5337e-10, 6.1572e-09,
3.7416e-10],
[1.1404e-12, 1.3320e-12, 1.5067e-12, ..., 1.0863e-11, 4.8380e-10,
1.1343e-12],
[2.0850e-10, 1.9130e-09, 2.0834e-09, ..., 1.5120e-09, 6.5665e-12,
9.9315e-10],
...,
[1.9818e-09, 4.2680e-10, 2.3925e-09, ..., 9.5324e-11, 7.4941e-11,
1.4853e-09],
[1.2304e-09, 1.0797e-09, 1.7017e-09, ..., 1.1278e-10, 2.8729e-11,
6.6484e-10],
[2.7203e-13, 1.6071e-13, 4.1965e-13, ..., 5.4241e-12, 6.0056e-12,
2.7558e-13]])},
3: {'step': tensor(19200.),
'exp_avg': tensor([-7.0225e-08, -4.9133e-07, 3.5006e-07, -2.1929e-07, 6.0389e-08,
3.3106e-07, 1.8911e-07, 2.8392e-07, 3.5667e-07, -7.3169e-07,
2.7290e-07, 4.8043e-07, -1.6849e-07, 2.7576e-08, 6.6279e-07,
-2.6198e-07, 1.1365e-37, -2.5275e-08, -3.3905e-07, -4.7240e-08]),
'exp_avg_sq': tensor([9.8487e-12, 4.2503e-12, 1.7769e-11, 5.5501e-12, 5.3832e-13, 2.3923e-11,
7.1915e-13, 1.4934e-11, 2.8864e-12, 9.7026e-12, 4.1048e-11, 5.9378e-12,
1.2309e-11, 3.2376e-13, 5.4346e-12, 8.4505e-12, 3.2725e-14, 1.0349e-11,
4.4490e-12, 1.2267e-12])},
4: {'step': tensor(19200.),
'exp_avg': tensor([[ 7.4749e-05, 4.3184e-05, 9.7801e-05, -7.2232e-05, -7.2613e-06,
-2.7013e-05, -1.2499e-05, 3.3588e-04, 9.9370e-05, 2.2992e-04,
7.6647e-05, 1.9692e-04, -8.3657e-05, 1.0178e-06, -6.2391e-07,
-7.0467e-05, -1.1380e-37, -1.4542e-04, -4.0040e-05, 2.8469e-06],
[-2.1076e-06, -4.5587e-06, -6.6048e-05, 5.0901e-05, -9.4604e-06,
-7.1190e-05, -2.7243e-05, 1.1919e-04, 7.2453e-05, 2.5918e-05,
-1.8415e-04, 8.0264e-05, 1.5577e-05, -4.3151e-06, 1.1833e-04,
3.5031e-05, 1.1247e-37, 1.6499e-04, -5.2236e-05, -5.8078e-05]]),
'exp_avg_sq': tensor([[5.3103e-07, 1.3128e-07, 1.6914e-06, 2.5481e-07, 6.5761e-08, 1.6418e-06,
4.1097e-08, 2.3069e-06, 3.2923e-07, 1.0400e-06, 1.1325e-06, 6.9244e-07,
2.4119e-07, 4.3331e-08, 7.8954e-07, 5.3556e-07, 1.9832e-15, 5.1401e-07,
1.1419e-07, 1.1320e-08],
[1.7107e-06, 6.4376e-07, 4.6067e-07, 3.2384e-07, 2.4660e-08, 5.3090e-07,
5.7447e-08, 5.0411e-07, 2.6994e-07, 1.1136e-06, 2.7019e-06, 1.4653e-07,
6.5892e-07, 8.6931e-09, 4.0941e-07, 3.5560e-07, 2.4615e-14, 1.3871e-06,
1.0544e-06, 2.0632e-07]])},
5: {'step': tensor(19200.),
'exp_avg': tensor([ 4.1278e-06, -1.7324e-07]),
'exp_avg_sq': tensor([1.0909e-09, 1.1236e-09])}},
'param_groups': [{'lr': 0.001,
'betas': (0.9, 0.999),
'eps': 1e-08,
'weight_decay': 0,
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': None,
'params': [0, 1, 2, 3, 4, 5]}]}
model_data0['optimizer_states'][0]
{'state': {0: {'step': tensor(24704.),
'exp_avg': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]),
'exp_avg_sq': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])},
1: {'step': tensor(24704.),
'exp_avg': tensor([-3.3291e-09, -2.6901e-07, -8.5368e-09, -6.1689e-09, -5.8358e-10,
-4.2298e-08, 1.2142e-08, -8.4066e-08, -1.1084e-08, 6.6887e-08,
1.1023e-08, -3.7201e-08, -7.6836e-09, 2.2619e-08, -5.2994e-09,
-6.2287e-08, -9.4547e-10, -1.4335e-09, 3.8968e-08, 2.2513e-08,
2.2364e-09, -1.1143e-09, 3.6313e-08, -4.9238e-08, -9.3458e-09,
6.9197e-08, 1.5287e-09, 8.8791e-09, -4.9362e-08, -1.3478e-09,
-3.2514e-08, 2.3516e-09, 3.7128e-09, 8.7992e-08, -7.9905e-08,
-9.8589e-08, -1.2815e-08, 2.1365e-08, 1.5157e-08, -4.6912e-09,
-2.3465e-08, -1.3516e-08, 3.5132e-08, -7.1798e-09, 2.0962e-08,
2.3280e-08, -1.4391e-09, -3.2664e-09, 1.0576e-08, -2.7358e-08,
3.9048e-08, 4.3482e-08, -8.9489e-10, 9.2530e-09, 1.1110e-07,
9.3799e-09, -1.1647e-07, 4.5923e-10, 2.5281e-08, -4.9357e-10,
2.8122e-09, -1.6036e-08, -5.4464e-08, 3.9694e-09, -1.2456e-08,
-3.0809e-09, 6.5620e-08, -2.5560e-09, 2.5220e-09, -8.9181e-10,
-1.0684e-08, 1.1933e-12, -1.4719e-08, -3.6113e-08, -4.9516e-08,
1.0140e-08, 2.9423e-08, 1.9112e-08, 2.4743e-08, -3.2870e-09,
7.1025e-09, -2.6974e-10, -3.5040e-08, -1.4150e-08, 9.0055e-09,
4.8752e-08, 2.4106e-08, -6.5053e-09, 8.2710e-09, 4.9359e-09,
-3.2327e-08, 1.4464e-08, -1.0219e-10, 4.3338e-08, 1.3530e-09,
2.3589e-08, 6.5122e-08, -7.3957e-09, 1.0636e-08, 7.5874e-08,
6.5495e-08]),
'exp_avg_sq': tensor([3.3493e-14, 4.5770e-13, 2.7220e-15, 8.5458e-15, 8.8367e-14, 1.5491e-13,
1.1714e-14, 2.7347e-13, 5.3222e-15, 1.6002e-13, 3.1381e-15, 1.0114e-13,
7.7224e-15, 8.3938e-15, 7.0083e-14, 1.1019e-13, 1.1026e-14, 4.1827e-16,
1.5740e-13, 1.6618e-13, 5.8237e-14, 6.8468e-15, 3.6846e-14, 1.2065e-13,
2.5912e-14, 1.2317e-13, 1.8945e-15, 5.6746e-14, 1.4344e-13, 5.8940e-15,
7.5823e-14, 5.9772e-16, 1.3587e-13, 1.8097e-13, 8.6020e-14, 8.1786e-14,
5.7440e-14, 4.1861e-14, 2.8075e-15, 6.4289e-16, 3.3174e-14, 8.5489e-15,
4.2327e-14, 3.2144e-15, 1.7530e-13, 6.0951e-14, 7.1471e-15, 1.7221e-15,
3.5197e-13, 8.2597e-14, 1.0789e-13, 3.9281e-14, 1.4633e-15, 1.6018e-14,
2.5271e-13, 3.3406e-14, 2.1086e-13, 1.2904e-14, 8.8726e-14, 1.6301e-16,
1.5266e-13, 7.2683e-15, 3.6584e-14, 3.2395e-15, 2.5544e-15, 1.0577e-13,
2.3522e-13, 3.2286e-14, 1.4691e-14, 5.5773e-16, 2.3658e-14, 5.9801e-15,
8.8853e-15, 3.0933e-14, 5.9003e-14, 4.8041e-14, 1.8997e-14, 3.3076e-14,
3.7029e-14, 1.7951e-15, 1.5096e-14, 9.8980e-14, 3.4488e-13, 3.3717e-15,
7.9273e-14, 6.2698e-14, 3.8048e-14, 9.2425e-16, 1.9873e-14, 3.6993e-15,
5.3648e-14, 8.3630e-14, 4.9177e-15, 1.7690e-14, 7.0701e-14, 9.8439e-14,
8.8344e-14, 7.4604e-14, 3.3676e-14, 8.5277e-14, 2.5562e-13])},
2: {'step': tensor(24704.),
'exp_avg': tensor([[-4.9326e-08, 1.1744e-06, 4.6956e-07, ..., 6.5900e-07,
-1.2494e-06, -9.6769e-06],
[-1.0783e-37, 1.1338e-37, 1.0727e-37, ..., 1.1544e-37,
1.1560e-37, 1.1439e-37],
[-4.5249e-08, 5.8841e-06, 2.7241e-06, ..., 3.0290e-06,
2.4178e-07, 5.7643e-06],
...,
[ 3.2185e-06, 9.8147e-07, -8.6194e-08, ..., 1.5874e-06,
-1.3662e-06, 1.3513e-06],
[-1.6921e-06, -6.0535e-06, -4.8379e-07, ..., -9.9990e-07,
-1.9927e-08, 1.5375e-09],
[-4.0679e-10, 9.8492e-08, 7.0411e-10, ..., -8.5458e-07,
-6.8398e-07, -4.3051e-06]]),
'exp_avg_sq': tensor([[2.0310e-13, 4.3699e-12, 3.3615e-12, ..., 1.2814e-12, 1.6978e-09,
4.6638e-09],
[1.2633e-17, 3.5333e-15, 7.5548e-17, ..., 7.1894e-17, 6.3369e-18,
6.6643e-17],
[6.1298e-13, 4.1991e-09, 2.4885e-10, ..., 3.1910e-09, 8.7600e-11,
1.6563e-09],
...,
[2.5256e-10, 1.0209e-09, 1.4628e-11, ..., 2.9486e-10, 2.7519e-11,
6.9032e-10],
[1.6946e-10, 2.7459e-09, 1.0914e-10, ..., 2.5188e-11, 4.4105e-13,
2.9463e-12],
[8.0220e-14, 1.6896e-12, 9.2324e-13, ..., 1.1071e-10, 5.8465e-10,
2.4667e-10]])},
3: {'step': tensor(24704.),
'exp_avg': tensor([-1.7259e-07, 1.1149e-37, 5.9648e-07, -1.6012e-07, -4.0144e-08,
-1.4859e-07, 3.2276e-09, 7.5898e-08, -9.3215e-09, -1.5498e-07,
1.5582e-07, -9.4949e-08, -5.6716e-08, 8.2442e-08, 1.6503e-07,
-4.5831e-08, -1.3578e-07, 7.7463e-08, -9.9964e-08, -7.8850e-08]),
'exp_avg_sq': tensor([2.3174e-12, 1.1540e-15, 8.4657e-12, 1.0602e-12, 8.4984e-13, 3.4273e-12,
1.1004e-13, 1.8538e-12, 1.0229e-12, 4.6470e-12, 1.4870e-12, 7.9469e-13,
1.0878e-12, 1.0385e-12, 6.9255e-13, 6.6388e-13, 8.0358e-13, 9.1733e-13,
4.7585e-13, 3.2537e-13])},
4: {'step': tensor(24704.),
'exp_avg': tensor([[ 1.3879e-05, -1.0978e-37, 3.1283e-06, -3.2041e-05, -8.3746e-05,
-5.4410e-05, -2.0653e-06, -6.0644e-05, 1.2586e-05, 5.4365e-05,
-2.8233e-05, -2.4750e-05, -3.0203e-05, -7.5921e-05, 1.1861e-05,
-1.0003e-05, 5.6459e-06, 2.1341e-05, -1.3047e-05, -8.5711e-05],
[-1.9475e-04, -1.0633e-37, -1.5675e-04, 4.5946e-05, -5.8112e-05,
8.7905e-06, 4.8987e-05, -1.8739e-05, 5.5974e-06, 2.7959e-05,
-1.6353e-05, -2.4068e-05, 9.0900e-06, -2.8180e-05, 2.1717e-05,
3.6198e-05, -4.8609e-05, 4.2364e-05, 7.9464e-05, -3.7461e-05]]),
'exp_avg_sq': tensor([[5.7826e-08, 8.0786e-16, 8.8738e-07, 8.2365e-08, 2.7641e-07, 2.6401e-07,
6.2966e-09, 5.2752e-07, 5.1849e-07, 1.7811e-06, 7.1370e-07, 3.9264e-07,
1.1426e-07, 4.0743e-07, 2.1667e-07, 3.9353e-08, 3.2557e-09, 1.7091e-07,
3.6437e-08, 2.6345e-07],
[1.0502e-06, 7.0775e-16, 1.4466e-06, 2.3260e-07, 1.5807e-07, 1.5568e-07,
7.9575e-08, 1.5304e-07, 1.7943e-07, 2.6096e-06, 1.5657e-07, 2.4886e-08,
1.2239e-07, 3.9524e-08, 3.6080e-07, 1.9693e-07, 4.9994e-08, 2.4139e-07,
3.6219e-07, 3.9706e-08]])},
5: {'step': tensor(24704.),
'exp_avg': tensor([-6.4559e-07, -1.2125e-06]),
'exp_avg_sq': tensor([2.0056e-10, 2.9153e-10])}},
'param_groups': [{'lr': 0.001,
'betas': (0.9, 0.999),
'eps': 1e-08,
'weight_decay': 0,
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': None,
'params': [0, 1, 2, 3, 4, 5]}]}
model_data_init['optimizer_states'][0]['param_groups']
[{'lr': 0.001,
'betas': (0.9, 0.999),
'eps': 1e-08,
'weight_decay': 0,
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': None,
'params': [0, 1, 2, 3, 4, 5]}]
Hard to get much out of these numbers. The basic parameters aren't changing much.
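Rather than eyeballing raw tensor dumps, it might be easier to reduce each Adam state entry to a few scalar norms. A sketch, assuming the checkpoint layout shown above (the function name is mine):

```python
import torch

def summarize_adam_state(ckpt):
    """Collapse each parameter's Adam moments to scalar norms (sketch)."""
    state = ckpt['optimizer_states'][0]['state']
    return {
        idx: {
            'step': int(entry['step']),
            'exp_avg_norm': entry['exp_avg'].norm().item(),
            'exp_avg_sq_norm': entry['exp_avg_sq'].norm().item(),
        }
        for idx, entry in state.items()
    }
```

Usage would be e.g. `summarize_adam_state(model_data0)`, giving one row of norms per parameter group instead of pages of tensors.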
Are the results different if we save only the weights during pretraining?¶
Experiment: initialize models, pretrain saving either weights only or not, compare.
init_data = pd.read_pickle('experiment_result/ex6_init2.pickle')
init_data.sort_values(['method', 'rep'])
|  | rep | method | Fisher |
|---|---|---|---|
| 8 | 0 | all | [9731.43904304492, 8632.25088476819, 9236.8925... |
| 9 | 0 | all | [9154.476091599896, 9358.899765289718, 7723.10... |
| 10 | 0 | all | [7900.043777763896, 6880.582472615025, 7635.87... |
| 11 | 0 | all | [7668.709673804187, 8520.00988880094, 9665.104... |
| 12 | 0 | all | [8073.243401894655, 7885.193828813932, 8202.21... |
| 13 | 0 | all | [9092.411254991597, 9974.333940118238, 8498.38... |
| 22 | 1 | all | [4716.853426608467, 5987.776276603933, 4814.12... |
| 23 | 1 | all | [5251.742508812979, 5046.3991082111115, 5478.3... |
| 24 | 1 | all | [6491.01862820697, 5418.43973423646, 5068.2023... |
| 25 | 1 | all | [5574.396421648473, 6090.717359568154, 5532.34... |
| 26 | 1 | all | [6283.682062739737, 5033.9707862196365, 5257.4... |
| 27 | 1 | all | [5736.774967734029, 5660.824241487151, 5224.84... |
| 36 | 2 | all | [7245.134008848942, 7376.052201136431, 7598.22... |
| 37 | 2 | all | [5821.538205425657, 5957.800981747067, 6563.98... |
| 38 | 2 | all | [6823.262177726626, 7335.203080726032, 6860.93... |
| 39 | 2 | all | [7724.908343458281, 6320.398675279556, 8219.12... |
| 40 | 2 | all | [8775.635508300766, 8271.719569157684, 8201.03... |
| 41 | 2 | all | [7256.410774290576, 7065.2717751842865, 7735.6... |
| 0 | 0 | init | [94.97747642708468, 109.0764109552019, 128.875... |
| 7 | 0 | init | [8.9423179738955, 9.557862311460575, 11.175286... |
| 14 | 1 | init | [63.87981372499806, 64.30695170332146, 64.7981... |
| 21 | 1 | init | [67.886799586104, 70.30323071279315, 68.704659... |
| 28 | 2 | init | [0.34414571095275365, 0.07857253573667755, 0.1... |
| 35 | 2 | init | [11.015241842309006, 15.101778242178428, 12.58... |
| 1 | 0 | weights | [7489.937823357886, 7298.444890711764, 6808.25... |
| 2 | 0 | weights | [6866.899285417158, 6418.306668619433, 7346.95... |
| 3 | 0 | weights | [6402.598569271286, 6265.5348798878495, 6175.5... |
| 4 | 0 | weights | [9434.350410260768, 7725.35230058199, 8216.120... |
| 5 | 0 | weights | [7756.657083003193, 6857.583918873309, 8427.94... |
| 6 | 0 | weights | [8042.877249508951, 7283.056005857774, 7762.27... |
| 15 | 1 | weights | [8534.682269248882, 9713.9064469059, 7807.6463... |
| 16 | 1 | weights | [9984.304641736942, 7825.095072366619, 8354.00... |
| 17 | 1 | weights | [11365.563566082628, 11240.831909742921, 9315.... |
| 18 | 1 | weights | [8846.371134361654, 9038.215308570376, 8712.06... |
| 19 | 1 | weights | [9737.283185256647, 8907.589827069598, 9443.63... |
| 20 | 1 | weights | [9547.901531240255, 10699.910357082279, 9155.6... |
| 29 | 2 | weights | [7360.646468316156, 7404.83808889236, 6220.085... |
| 30 | 2 | weights | [6257.006131289272, 7131.237135011165, 6421.76... |
| 31 | 2 | weights | [6734.224614825596, 5947.305354442145, 6991.21... |
| 32 | 2 | weights | [7582.8126041177775, 6751.791311037392, 7093.8... |
| 33 | 2 | weights | [6073.9006480561175, 6102.018964492787, 6490.7... |
| 34 | 2 | weights | [5292.418137873429, 6522.818005062166, 6758.61... |
plt.title('Cov - All, init, weights')
sns.heatmap(np.corrcoef( np.array(init_data.sort_values(['method', 'rep'])['Fisher'].to_list() )))
<Axes: title={'center': 'Cov - All, init, weights'}>
initial_runs = np.array(init_data[ init_data.method == 'init']['Fisher'].to_list())
figsize(4,3.2)
plt.title('initial runs - removing mean')
sns.heatmap(np.corrcoef( initial_runs - initial_runs.mean(0) ))
<Axes: title={'center': 'initial runs - removing mean'}>
colors = plt.cm.viridis(np.linspace(0,1,3))
figsize(20, 20)
plt.subplots(3,1)
plt.subplot(3,1,1)
plt.title('Initialization')
for row in init_data[init_data.method == 'init'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])
plt.subplot(3,1,2)
plt.title('All')
for row in init_data[init_data.method == 'all'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])
plt.subplot(3,1,3)
plt.title('weights')
for row in init_data[init_data.method == 'weights'].to_dict(orient="records"):
    plt.plot(row['Fisher'], c=colors[row['rep']])
Ok. These results are very interesting¶
Fisher curves with the same initialization are correlated regardless of how the model is saved.
- this can be seen in the curves themselves: they are more similar within colors than between
- it also shows in the correlation plot as blocks of 6 along the diagonal
The correlation in the initialization curves is due to the non-uniform distribution: removing the mean removes most of the correlation.
- However, it does introduce negative correlation, presumably because the mean is computed from these trajectories themselves.
The weight-only curves are highly correlated across initializations, and highly correlated with the Fisher information from the initialization runs.
- This holds regardless of whether the initialization run corresponds to the particular weight-only curve in question, which suggests it results from a failure to remove the initial Fisher information trend during weight-only training.
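The negative correlation introduced by mean removal can be reproduced with purely independent data: centering N iid rows by their shared sample mean forces an average pairwise correlation of about -1/(N-1). A sketch with synthetic stand-in curves:

```python
import numpy as np

rng = np.random.default_rng(2)
N = 6                                    # six trajectories, as in the experiment
X = rng.standard_normal((N, 500))        # independent stand-in "Fisher curves"
Xc = X - X.mean(axis=0)                  # subtract the mean across trajectories
C = np.corrcoef(Xc)
off_diag = C[~np.eye(N, dtype=bool)]

# centering with a mean estimated from the same rows induces
# correlation of roughly -1/(N-1) = -0.2 between otherwise independent rows
print(off_diag.mean())
```

So the negative correlation here is an artifact of the centering, not evidence of any real anti-correlation between the runs.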
copy = init_data[init_data.method == 'init'].copy()   # explicit .copy() avoids SettingWithCopyWarning
mean = copy['Fisher'].mean()
copy['Fisher'] = copy.apply(lambda r: r['Fisher'] - mean, axis=1)
copy['method'] = 'init_mean_rm'
init_data_aug = pd.concat([init_data, copy])
plt.title('Cov - including mean removal')
sns.heatmap(np.corrcoef( np.array(init_data_aug.sort_values(['method', 'rep'])['Fisher'].to_list() )))
<Axes: title={'center': 'Cov - including mean removal'}>
Introducing mean removal¶
Again, there is a sizable degree of negative correlation.
The mean-removed versions show little correlation with either the weight-only or the all-parameter models.
However, they do pick out quite well the initialization run on which each retraining result is based.
Initialization conclusions:¶
There is, in fact, a sizable impact of initialization on the networks that are learned. This holds regardless of how the network is saved and reloaded. Saving network and training parameters beyond the weights seems to make the retraining more effective at removing general trends from the network, but the other initialization effects remain.
On one hand, this is exactly the type of effect I was hoping to see: there is a residual of the pretraining that can be detected. It is very interesting that saving the state of the trainer is sufficient to remove it.
On the other hand, initialization-dependent correlation remains. Perhaps this is not altogether surprising: the network doesn't need to unlearn quirks of the initialization if they don't impact its generalization ability. For a network that generalizes well after the initial training, we wouldn't expect a second round of training on freshly sampled data from the same distribution to do much, if anything.